"""Functions for reading the compressed training/validation data records"""
import os
import sys
sys.path.append(os.path.dirname(__file__))
import numpy as np
from glob import glob
from tensorpack import dataflow
import pandas as pd
import zstandard as zstd
import msgpack
import msgpack_numpy
import pickle
import helper
from argoverse.data_loading.argoverse_forecasting_loader import ArgoverseForecastingLoader
from argoverse.map_representation.map_api import ArgoverseMap
msgpack_numpy.patch()


class ArgoverseDataFlowV2(dataflow.RNGDataFlow):
    """
    Data flow for argoverse dataset.

    Iterates over the scenes of an ``ArgoverseForecastingLoader`` and yields one
    feature dict per scene. Scenes with more than ``max_car_num`` tracks are
    skipped. Keys of each yielded dict:

    * ``city``, ``scene_idx``, ``agent_id``: scene metadata.
    * ``lane`` / ``lane_norm``: lane points from ``helper.get_lanes`` and their
      looked-up directions, each with a zero inserted as third component
      (a single zero row of each when ``use_lane`` is False).
    * ``car_mask``: ``(max_car_num, 1)`` float32, 1.0 for rows holding a real track.
    * ``pos_2s`` / ``vel_2s``: per-track positions / velocities at the first 20
      timestamps, zero-padded to ``max_car_num`` tracks. The first row of every
      track is consumed by the velocity difference.
    * ``pos{t}``, ``vel{t}``, ``track_id{t}``, ``frame_id{t}`` for ``t`` in 0..29:
      per-track state at timestamp index ``20 + t``, padded to ``max_car_num``
      rows (track ids are padded with ``'dummy<i>'`` strings).
    """

    def __init__(self, file_path: str, shuffle: bool = True, random_rotation: bool = False,
                 max_car_num: int = 60, freq: int = 10, use_interpolate: bool = False,
                 lane_path: str = "/home/leo/particle/TrafficFluids/datasets",
                 use_lane: bool = False, use_mask: bool = True):
        """
        Args:
            file_path: directory handed to ``ArgoverseForecastingLoader``.
            shuffle: shuffle scene order on every iteration using ``self.rng``.
            random_rotation: stored only; not read anywhere in this class.
            max_car_num: scenes with more tracks are skipped; per-track arrays
                are padded to this many rows.
            freq: multiplier turning consecutive position differences into velocities.
            use_interpolate: if True, fill missing observations by linear
                interpolation; otherwise drop tracks not observed at each of the
                first 50 timestamps.
            lane_path: directory containing ``lanes.pkl`` and ``lane_drct.pkl``.
            use_lane: load and emit lane features.
            use_mask: stored only; ``car_mask`` is always emitted.

        Raises:
            FileNotFoundError: if ``file_path`` does not exist.
        """
        if not os.path.exists(file_path):
            raise FileNotFoundError("Path does not exist.")

        self.afl = ArgoverseForecastingLoader(file_path)
        self.shuffle = shuffle
        self.random_rotation = random_rotation  # NOTE(review): currently unused
        self.max_car_num = max_car_num
        self.freq = freq
        self.use_interpolate = use_interpolate
        self.am = ArgoverseMap()
        self.use_lane = use_lane
        self.use_mask = use_mask  # NOTE(review): currently unused

        if self.use_lane:
            # pickle.load can execute arbitrary code: only point lane_path at trusted files.
            with open(os.path.join(lane_path, 'lanes.pkl'), 'rb') as f:
                self.lanes = pickle.load(f)
            with open(os.path.join(lane_path, 'lane_drct.pkl'), 'rb') as f:
                self.lane_drct = pickle.load(f)
            # The pickles hold per-city tables that used to be built inline,
            # roughly as follows (kept for reference when regenerating them):
            #   for k in ('MIA', 'PIT'):
            #       lanes[k] = _expand_dim(np.concatenate(helper.get_all_lanes(k, am), axis=0))
            #       lane_drct[k] = helper.get_lane_direction(lanes[k], k, am)

    def __iter__(self):
        """Yield one feature dict per eligible scene (see the class docstring)."""
        scene_idxs = np.arange(len(self.afl))
        if self.shuffle:
            self.rng.shuffle(scene_idxs)

        for scene in scene_idxs:
            seq = self.afl[scene]
            if seq.num_tracks > self.max_car_num:
                continue

            data, city = seq.seq_df, seq.city
            lane, lane_drct = self._scene_lane_features(data, city)

            # Sorted timestamps of the raw sequence: indices 0..19 feed the *_2s
            # features, indices 20..49 feed the per-frame entries below.
            tstmps = data.TIMESTAMP.unique()
            tstmps.sort()

            if self.use_interpolate:
                # TODO: Threshold, tunable
                data = self._expand_df(data, city)
                data = self._linear_interpolate(data)
            else:
                data = self._filter_imcomplete_data(data, tstmps, 50)

            data = self._calc_vel(data, self.freq)

            agent = data[data['OBJECT_TYPE'] == 'AGENT']['TRACK_ID'].values[0]

            car_mask = np.zeros((self.max_car_num, 1), dtype=np.float32)
            car_mask[:len(data.TRACK_ID.unique())] = 1.0

            feat_dict = {'city': city,
                         'lane': lane,
                         'lane_norm': lane_drct,
                         'scene_idx': scene,
                         'agent_id': agent,
                         'car_mask': car_mask}

            # Group the history window once and reuse it for positions and velocities.
            history = data[data['TIMESTAMP'].isin(tstmps[:20])]
            tracks = [subdf for _, subdf in history.groupby('TRACK_ID')]

            pos_enc = np.concatenate([subdf[['X', 'Y']].values[np.newaxis, :] for subdf in tracks], axis=0)
            feat_dict['pos_2s'] = self._expand_particle(self._expand_dim(pos_enc), self.max_car_num, 0)

            vel_enc = np.concatenate([subdf[['vel_x', 'vel_y']].values[np.newaxis, :] for subdf in tracks],
                                     axis=0)
            feat_dict['vel_2s'] = self._expand_particle(self._expand_dim(vel_enc), self.max_car_num, 0)

            for t in range(30):
                frame = data[data['TIMESTAMP'] == tstmps[20 + t]]
                pos = self._expand_dim(frame[['X', 'Y']].values)
                feat_dict['pos' + str(t)] = self._expand_particle(pos, self.max_car_num, 0)
                vel = self._expand_dim(frame[['vel_x', 'vel_y']].values)
                feat_dict['vel' + str(t)] = self._expand_particle(vel, self.max_car_num, 0)
                feat_dict['track_id' + str(t)] = self._expand_particle(
                    frame['TRACK_ID'].values, self.max_car_num, 0, 'str')
                feat_dict['frame_id' + str(t)] = t

            yield feat_dict

    def _scene_lane_features(self, data, city):
        """Return ``(lane, lane_drct)`` for one scene, or zero placeholders if lanes are disabled."""
        if not self.use_lane:
            return (np.array([[0., 0., 0.]], dtype=np.float32),
                    np.array([[0., 0., 0.]], dtype=np.float32))
        lane = self._expand_dim(np.concatenate(helper.get_lanes(data, city, self.am), axis=0))
        lane_drct = self._expand_dim(self._look_up_lane_drct(lane, city))
        return lane, lane_drct

    @classmethod
    def __expand_df_generator(cls, df, city_name):
        """Yield, for every timestamp and every track, either that track's rows or a
        one-row placeholder with NaN ``X``/``Y`` when the track is absent."""
        ids = df.TRACK_ID.unique()
        # OBJECT_TYPE of each track's first row, looked up once instead of per gap.
        first_rows = df.drop_duplicates('TRACK_ID')
        object_types = dict(zip(first_rows['TRACK_ID'], first_rows['OBJECT_TYPE']))
        for tstmp, sub_df in df.groupby('TIMESTAMP'):
            present = set(sub_df['TRACK_ID'])
            for idx in ids:
                if idx not in present:
                    yield pd.DataFrame(dict(TIMESTAMP=[tstmp], TRACK_ID=[idx], X=[np.nan], Y=[np.nan],
                                            CITY_NAME=[city_name],
                                            OBJECT_TYPE=[object_types[idx]]))
                else:
                    # sub_df already holds exactly this timestamp's rows, so filtering
                    # it matches filtering the whole frame on both columns.
                    yield sub_df[sub_df['TRACK_ID'] == idx]

    @classmethod
    def _expand_df(cls, df, city_name):
        """Return ``df`` with a NaN-position row for every (timestamp, track) gap."""
        return pd.concat(cls.__expand_df_generator(df, city_name), axis=0)

    @classmethod
    def __calc_vel_generator(cls, df, freq=10):
        """Yield each track with ``vel_x``/``vel_y`` added, minus its first row."""
        for _, subdf in df.groupby('TRACK_ID'):
            track = subdf.copy()
            # Assumes each track's rows are in timestamp order (TODO confirm for all inputs).
            track[['vel_x', 'vel_y']] = track[['X', 'Y']].diff() * freq
            # The first row has no predecessor (NaN velocity), so it is dropped.
            yield track.iloc[1:, :]

    @classmethod
    def _calc_vel(cls, df, freq=10):
        """Return ``df`` with finite-difference velocities; each track loses its first row."""
        return pd.concat(cls.__calc_vel_generator(df, freq=freq), axis=0)

    @classmethod
    def _expand_dim(cls, ndarr, dtype=np.float32):
        """Insert a zero at index 2 of the last axis (e.g. ``(x, y)`` -> ``(x, y, 0)``) and cast."""
        return np.insert(ndarr, 2, values=0, axis=-1).astype(dtype)

    @classmethod
    def _linear_interpolate_generator(cls, data, col=['X', 'Y']):
        """Yield each track with NaNs in ``col`` linearly interpolated (both directions)."""
        for _, df in data.groupby('TRACK_ID'):
            sub_df = df.copy()
            sub_df[col] = sub_df[col].interpolate(limit_direction='both')
            yield sub_df

    @classmethod
    def _linear_interpolate(cls, data, col=['X', 'Y']):
        """Return ``data`` with ``col`` interpolated per track."""
        return pd.concat(cls._linear_interpolate_generator(data, col), axis=0)

    @classmethod
    def _filter_imcomplete_data(cls, data, tstmps, window=20):
        """Keep only tracks that have exactly ``window`` rows within ``tstmps[:window]``."""
        in_window = data[data['TIMESTAMP'].isin(tstmps[:window])]
        complete_id = [idx for idx, subdf in in_window.groupby('TRACK_ID') if len(subdf) == window]
        return data[data['TRACK_ID'].isin(complete_id)]

    @classmethod
    def _expand_particle(cls, arr, max_num, axis, value_type='int'):
        """Pad ``arr`` along ``axis`` up to ``max_num`` entries.

        Numeric padding is ``np.zeros`` (float64, so float32 input is upcast by
        the concatenation). With ``value_type='str'`` the padding is the strings
        ``'dummy0'``, ``'dummy1'``, ...

        Raises:
            ValueError: if ``arr`` already has more than ``max_num`` entries along ``axis``.
        """
        pad = max_num - arr.shape[axis]
        if pad < 0:
            raise ValueError("Got {} entries along axis {}, more than max_num={}".format(
                arr.shape[axis], axis, max_num))
        dummy_shape = list(arr.shape)
        dummy_shape[axis] = pad
        if value_type == 'str':
            # np.product was removed in NumPy 2.0; dtype=str keeps the empty case a string array.
            n_dummy = int(np.prod(dummy_shape))
            dummy = np.array(['dummy' + str(i) for i in range(n_dummy)], dtype=str).reshape(dummy_shape)
        else:
            dummy = np.zeros(dummy_shape)
        return np.concatenate([arr, dummy], axis=axis)

    def __look_up_lane_drct(self, lane, city):
        """Return the direction stored for the lane point ``lane`` as a ``(1, D)`` array.

        Raises:
            ValueError: if no entry of ``self.lanes[city]`` equals ``lane``.
        """
        matches = self.lane_drct[city][np.equal(self.lanes[city], lane).all(axis=1)]
        if len(matches) == 0:
            raise ValueError("No lane direction found for point {} in {}".format(lane, city))
        # np.unique without axis flattens and sorts the scalars, scrambling (dx, dy);
        # dedupe whole rows instead and keep the first one.
        return np.unique(matches, axis=0)[:1]

    def _look_up_lane_drct(self, lane, city):
        """Stack the looked-up direction of every point in ``lane`` (one row per point)."""
        return np.concatenate([self.__look_up_lane_drct(point, city) for point in lane], axis=0)
        


def read_data(file_path=None,
              batch_size=1,
              random_rotation=False,
              repeat=False,
              shuffle_buffer=None,
              num_workers=1,
              cache_data=False,
              **kwargs):
    """Build a batched tensorpack dataflow over an Argoverse forecasting split.

    Args:
        file_path: directory handed to ``ArgoverseDataFlowV2``.
        batch_size: datapoints per batch (``BatchData`` with ``use_list=True``).
        random_rotation: forwarded to ``ArgoverseDataFlowV2``.
        repeat: if truthy, repeat the data indefinitely.
        shuffle_buffer: if truthy, the scene order is shuffled and a
            ``LocallyShuffleData`` stage with this buffer size is added.
        num_workers: if > 1, run the dataflow in that many processes.
            NOTE(review): tensorpack runs a full copy of the dataflow in each
            process, so samples may be duplicated per epoch — confirm intended.
        cache_data: cache the batched data after the first pass.
        **kwargs: extra ``ArgoverseDataFlowV2`` options.

    Returns:
        The assembled dataflow, already ``reset_state()``-ed.

    Raises:
        ValueError: if ``cache_data`` is combined with ``repeat``,
            ``random_rotation`` or ``num_workers != 1``.
    """
    # Caching only makes sense if the data is finite and identical each pass.
    # Use the same truthiness tests as the pipeline below, so that e.g. repeat=2
    # cannot pass the guard and still enable RepeatedData.
    if cache_data:
        if repeat:
            raise ValueError("repeat must be False if cache_data==True")
        if random_rotation:
            raise ValueError("random_rotation must be False if cache_data==True")
        if num_workers != 1:
            raise ValueError("num_workers must be 1 if cache_data==True")

    df = ArgoverseDataFlowV2(
        file_path=file_path,
        random_rotation=random_rotation,
        shuffle=bool(shuffle_buffer),
        **kwargs
    )

    if repeat:
        df = dataflow.RepeatedData(df, -1)

    if shuffle_buffer:
        df = dataflow.LocallyShuffleData(df, shuffle_buffer)

    if num_workers > 1:
        df = dataflow.MultiProcessRunnerZMQ(df, num_proc=num_workers)

    df = dataflow.BatchData(df, batch_size=batch_size, use_list=True)

    if cache_data:
        df = dataflow.CacheData(df)

    df.reset_state()
    return df


def read_data_val(file_path, **kwargs):
    """Validation reader: batches of one, a single pass, no shuffling, one worker.

    Any extra keyword arguments go to ``read_data`` (which forwards unknown
    ones to ``ArgoverseDataFlowV2``). They must not repeat the fixed options
    below, or Python raises ``TypeError`` for the duplicate keyword.
    """
    fixed_options = {
        'batch_size': 1,
        'repeat': False,
        'shuffle_buffer': None,
        'num_workers': 1,
    }
    return read_data(file_path=file_path, **fixed_options, **kwargs)


def read_data_train(file_path, batch_size, random_rotation=True, shuffle_buffer=None, **kwargs):
    """Training reader built on ``read_data``.

    Args:
        file_path: directory of the training split.
        batch_size: datapoints per batch.
        random_rotation: forwarded to ``read_data``. Defaults to True here, which
            ``read_data`` rejects when combined with ``cache_data=True``.
        shuffle_buffer: forwarded to ``read_data``. A truthy value enables scene
            shuffling plus a local shuffle buffer of that size. Defaults to None
            (no shuffling), matching the previously hard-coded value. It used to
            be impossible to override, because passing it via ``**kwargs``
            raised ``TypeError``.
        **kwargs: any other ``read_data`` / ``ArgoverseDataFlowV2`` option.
    """
    return read_data(file_path=file_path,
                     batch_size=batch_size,
                     random_rotation=random_rotation,
                     shuffle_buffer=shuffle_buffer,
                     **kwargs)
